{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 使用SVM检测蘑菇🍄是否有毒\n",
    "\n",
    "需要在标有TODO的地方添加代码.\n",
    "\n",
    "使用的数据集: https://archive.ics.uci.edu/ml/datasets/mushroom\n",
    " \n",
    "数据集中每一条数据包含如下特征, 特征包括对蘑菇形状, 质地, 色彩等特征的描述, 我们需要以此判断🍄是否有毒(p)或者可以吃(e). 为两个类别的分类问题.\n",
    "\n",
    "1. cap-shape: bell=b,conical=c,convex=x,flat=f, knobbed=k,sunken=s \n",
    "2. cap-surface: fibrous=f,grooves=g,scaly=y,smooth=s \n",
    "3. cap-color: brown=n,buff=b,cinnamon=c,gray=g,green=r, pink=p,purple=u,red=e,white=w,yellow=y \n",
    "4. bruises?: bruises=t,no=f \n",
    "5. odor: almond=a,anise=l,creosote=c,fishy=y,foul=f, musty=m,none=n,pungent=p,spicy=s \n",
    "6. gill-attachment: attached=a,descending=d,free=f,notched=n \n",
    "7. gill-spacing: close=c,crowded=w,distant=d \n",
    "8. gill-size: broad=b,narrow=n \n",
    "9. gill-color: black=k,brown=n,buff=b,chocolate=h,gray=g, green=r,orange=o,pink=p,purple=u,red=e, white=w,yellow=y \n",
    "10. stalk-shape: enlarging=e,tapering=t \n",
    "11. stalk-root: bulbous=b,club=c,cup=u,equal=e, rhizomorphs=z,rooted=r,missing=? \n",
    "12. stalk-surface-above-ring: fibrous=f,scaly=y,silky=k,smooth=s \n",
    "13. stalk-surface-below-ring: fibrous=f,scaly=y,silky=k,smooth=s \n",
    "14. stalk-color-above-ring: brown=n,buff=b,cinnamon=c,gray=g,orange=o, pink=p,red=e,white=w,yellow=y \n",
    "15. stalk-color-below-ring: brown=n,buff=b,cinnamon=c,gray=g,orange=o, pink=p,red=e,white=w,yellow=y \n",
    "16. veil-type: partial=p,universal=u \n",
    "17. veil-color: brown=n,orange=o,white=w,yellow=y \n",
    "18. ring-number: none=n,one=o,two=t \n",
    "19. ring-type: cobwebby=c,evanescent=e,flaring=f,large=l, none=n,pendant=p,sheathing=s,zone=z \n",
    "20. spore-print-color: black=k,brown=n,buff=b,chocolate=h,green=r, orange=o,purple=u,white=w,yellow=y \n",
    "21. population: abundant=a,clustered=c,numerous=n, scattered=s,several=v,solitary=y \n",
    "22. habitat: grasses=g,leaves=l,meadows=m,paths=p, urban=u,waste=w,woods=d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>class</th>\n",
       "      <th>cap-shape</th>\n",
       "      <th>cap-surface</th>\n",
       "      <th>cap-color</th>\n",
       "      <th>bruises</th>\n",
       "      <th>odor</th>\n",
       "      <th>gill-attachment</th>\n",
       "      <th>gill-spacing</th>\n",
       "      <th>gill-size</th>\n",
       "      <th>gill-color</th>\n",
       "      <th>...</th>\n",
       "      <th>stalk-surface-below-ring</th>\n",
       "      <th>stalk-color-above-ring</th>\n",
       "      <th>stalk-color-below-ring</th>\n",
       "      <th>veil-type</th>\n",
       "      <th>veil-color</th>\n",
       "      <th>ring-number</th>\n",
       "      <th>ring-type</th>\n",
       "      <th>spore-print-color</th>\n",
       "      <th>population</th>\n",
       "      <th>habitat</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>p</td>\n",
       "      <td>x</td>\n",
       "      <td>s</td>\n",
       "      <td>n</td>\n",
       "      <td>t</td>\n",
       "      <td>p</td>\n",
       "      <td>f</td>\n",
       "      <td>c</td>\n",
       "      <td>n</td>\n",
       "      <td>k</td>\n",
       "      <td>...</td>\n",
       "      <td>s</td>\n",
       "      <td>w</td>\n",
       "      <td>w</td>\n",
       "      <td>p</td>\n",
       "      <td>w</td>\n",
       "      <td>o</td>\n",
       "      <td>p</td>\n",
       "      <td>k</td>\n",
       "      <td>s</td>\n",
       "      <td>u</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>e</td>\n",
       "      <td>x</td>\n",
       "      <td>s</td>\n",
       "      <td>y</td>\n",
       "      <td>t</td>\n",
       "      <td>a</td>\n",
       "      <td>f</td>\n",
       "      <td>c</td>\n",
       "      <td>b</td>\n",
       "      <td>k</td>\n",
       "      <td>...</td>\n",
       "      <td>s</td>\n",
       "      <td>w</td>\n",
       "      <td>w</td>\n",
       "      <td>p</td>\n",
       "      <td>w</td>\n",
       "      <td>o</td>\n",
       "      <td>p</td>\n",
       "      <td>n</td>\n",
       "      <td>n</td>\n",
       "      <td>g</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>e</td>\n",
       "      <td>b</td>\n",
       "      <td>s</td>\n",
       "      <td>w</td>\n",
       "      <td>t</td>\n",
       "      <td>l</td>\n",
       "      <td>f</td>\n",
       "      <td>c</td>\n",
       "      <td>b</td>\n",
       "      <td>n</td>\n",
       "      <td>...</td>\n",
       "      <td>s</td>\n",
       "      <td>w</td>\n",
       "      <td>w</td>\n",
       "      <td>p</td>\n",
       "      <td>w</td>\n",
       "      <td>o</td>\n",
       "      <td>p</td>\n",
       "      <td>n</td>\n",
       "      <td>n</td>\n",
       "      <td>m</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>p</td>\n",
       "      <td>x</td>\n",
       "      <td>y</td>\n",
       "      <td>w</td>\n",
       "      <td>t</td>\n",
       "      <td>p</td>\n",
       "      <td>f</td>\n",
       "      <td>c</td>\n",
       "      <td>n</td>\n",
       "      <td>n</td>\n",
       "      <td>...</td>\n",
       "      <td>s</td>\n",
       "      <td>w</td>\n",
       "      <td>w</td>\n",
       "      <td>p</td>\n",
       "      <td>w</td>\n",
       "      <td>o</td>\n",
       "      <td>p</td>\n",
       "      <td>k</td>\n",
       "      <td>s</td>\n",
       "      <td>u</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>e</td>\n",
       "      <td>x</td>\n",
       "      <td>s</td>\n",
       "      <td>g</td>\n",
       "      <td>f</td>\n",
       "      <td>n</td>\n",
       "      <td>f</td>\n",
       "      <td>w</td>\n",
       "      <td>b</td>\n",
       "      <td>k</td>\n",
       "      <td>...</td>\n",
       "      <td>s</td>\n",
       "      <td>w</td>\n",
       "      <td>w</td>\n",
       "      <td>p</td>\n",
       "      <td>w</td>\n",
       "      <td>o</td>\n",
       "      <td>e</td>\n",
       "      <td>n</td>\n",
       "      <td>a</td>\n",
       "      <td>g</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 23 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "  class cap-shape cap-surface cap-color bruises odor gill-attachment  \\\n",
       "0     p         x           s         n       t    p               f   \n",
       "1     e         x           s         y       t    a               f   \n",
       "2     e         b           s         w       t    l               f   \n",
       "3     p         x           y         w       t    p               f   \n",
       "4     e         x           s         g       f    n               f   \n",
       "\n",
       "  gill-spacing gill-size gill-color  ... stalk-surface-below-ring  \\\n",
       "0            c         n          k  ...                        s   \n",
       "1            c         b          k  ...                        s   \n",
       "2            c         b          n  ...                        s   \n",
       "3            c         n          n  ...                        s   \n",
       "4            w         b          k  ...                        s   \n",
       "\n",
       "  stalk-color-above-ring stalk-color-below-ring veil-type veil-color  \\\n",
       "0                      w                      w         p          w   \n",
       "1                      w                      w         p          w   \n",
       "2                      w                      w         p          w   \n",
       "3                      w                      w         p          w   \n",
       "4                      w                      w         p          w   \n",
       "\n",
       "  ring-number ring-type spore-print-color population habitat  \n",
       "0           o         p                 k          s       u  \n",
       "1           o         p                 n          n       g  \n",
       "2           o         p                 n          n       m  \n",
       "3           o         p                 k          s       u  \n",
       "4           o         e                 n          a       g  \n",
       "\n",
       "[5 rows x 23 columns]"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "# 导入数据\n",
    "mush_df = pd.read_csv('mushrooms.csv')\n",
    "\n",
    "# 将值从字母转换为\n",
    "mush_df_encoded = pd.get_dummies(mush_df)\n",
    "\n",
    "mush_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>class_e</th>\n",
       "      <th>class_p</th>\n",
       "      <th>cap-shape_b</th>\n",
       "      <th>cap-shape_c</th>\n",
       "      <th>cap-shape_f</th>\n",
       "      <th>cap-shape_k</th>\n",
       "      <th>cap-shape_s</th>\n",
       "      <th>cap-shape_x</th>\n",
       "      <th>cap-surface_f</th>\n",
       "      <th>cap-surface_g</th>\n",
       "      <th>...</th>\n",
       "      <th>population_s</th>\n",
       "      <th>population_v</th>\n",
       "      <th>population_y</th>\n",
       "      <th>habitat_d</th>\n",
       "      <th>habitat_g</th>\n",
       "      <th>habitat_l</th>\n",
       "      <th>habitat_m</th>\n",
       "      <th>habitat_p</th>\n",
       "      <th>habitat_u</th>\n",
       "      <th>habitat_w</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 119 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   class_e  class_p  cap-shape_b  cap-shape_c  cap-shape_f  cap-shape_k  \\\n",
       "0        0        1            0            0            0            0   \n",
       "1        1        0            0            0            0            0   \n",
       "2        1        0            1            0            0            0   \n",
       "3        0        1            0            0            0            0   \n",
       "4        1        0            0            0            0            0   \n",
       "\n",
       "   cap-shape_s  cap-shape_x  cap-surface_f  cap-surface_g  ...  population_s  \\\n",
       "0            0            1              0              0  ...             1   \n",
       "1            0            1              0              0  ...             0   \n",
       "2            0            0              0              0  ...             0   \n",
       "3            0            1              0              0  ...             1   \n",
       "4            0            1              0              0  ...             0   \n",
       "\n",
       "   population_v  population_y  habitat_d  habitat_g  habitat_l  habitat_m  \\\n",
       "0             0             0          0          0          0          0   \n",
       "1             0             0          0          1          0          0   \n",
       "2             0             0          0          0          0          1   \n",
       "3             0             0          0          0          0          0   \n",
       "4             0             0          0          1          0          0   \n",
       "\n",
       "   habitat_p  habitat_u  habitat_w  \n",
       "0          0          1          0  \n",
       "1          0          0          0  \n",
       "2          0          0          0  \n",
       "3          0          1          0  \n",
       "4          0          0          0  \n",
       "\n",
       "[5 rows x 119 columns]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mush_df_encoded.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 将特征和类别标签分布赋值给 X 和 y\n",
    "X_mush = mush_df_encoded.iloc[:,2:]\n",
    "y_mush = mush_df_encoded.iloc[:,1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练SVM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 建立寻pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.svm import SVC\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.pipeline import make_pipeline\n",
    "\n",
    "# TODO\n",
    "pca = PCA(n_components=20, whiten=True, random_state=41)#发现n_components在一定范围内维数越高准确率也就越高\n",
    "svc = SVC(kernel='linear', class_weight='balanced')\n",
    "model = make_pipeline(pca, svc)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 将数据分为训练和测试数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "Xtrain, Xtest, ytrain, ytest = train_test_split(X_mush, y_mush,\n",
    "                                                random_state=41)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 调参:通过交叉验证寻找最佳的 C (控制间隔的大小)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/anaconda2/envs/py3/lib/python3.7/site-packages/sklearn/model_selection/_split.py:1978: FutureWarning: The default value of cv will change from 3 to 5 in version 0.22. Specify it explicitly to silence this warning.\n",
      "  warnings.warn(CV_WARNING, FutureWarning)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 23.7 s, sys: 0 ns, total: 23.7 s\n",
      "Wall time: 23.5 s\n",
      "{'svc__C': 50}\n"
     ]
    }
   ],
   "source": [
    "from sklearn.model_selection import GridSearchCV\n",
    "\n",
    "# TODO\n",
    "param_grid = {'svc__C': [1, 5, 10, 50]}\n",
    "grid = GridSearchCV(model, param_grid)\n",
    "\n",
    "%time grid.fit(Xtrain, ytrain)\n",
    "print(grid.best_params_)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 使用训练好的SVM做预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "正确率: 0.9758739537173806\n"
     ]
    }
   ],
   "source": [
    "# TODO\n",
    "from sklearn.metrics import accuracy_score\n",
    "model = grid.best_estimator_\n",
    "yfit = model.predict(Xtest)\n",
    "print(\"正确率:\",accuracy_score(ytest, yfit))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 生成性能报告"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0       0.98      0.97      0.98      1047\n",
      "           1       0.97      0.98      0.98       984\n",
      "\n",
      "    accuracy                           0.98      2031\n",
      "   macro avg       0.98      0.98      0.98      2031\n",
      "weighted avg       0.98      0.98      0.98      2031\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import classification_report\n",
    "print(classification_report(ytest, yfit))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "coursera": {
   "course_slug": "python-machine-learning",
   "graded_item_id": "eWYHL",
   "launcher_item_id": "BAqef",
   "part_id": "fXXRp"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
